-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Welford Scheduling Support #561
Conversation
…into welford_rebase
@@ -918,6 +970,15 @@ generateIndexAndExtentMap( | |||
loops.pop_back(); | |||
} | |||
|
|||
if (tv->definition()->isA<WelfordOp>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this assume WelfordOp
is the only expression type with multiple outputs? If so, would it be possible to generalize it so that it could work with any future expressions with multiple outputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. WelfordOp
is currently the only multi-output use case we considered so far. This PR was trying to support WelfordOp
with minimal generalizations but if we have other multi-output cases we can generalize.
On the index compute side the implementation is more temporary than architectural due to the limitation that the loop variable now can only be mapped to one of the outputs. This part will be re-factored after we switch to index compute based on local compute-at and domain maps (@csarofeen). I'd prefer adding multi-output support at that point if we do decide to generalize.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's okay to start with something specific to Welford and generalize it later, as long as it is guarded with assertion about the assumption. For example, if something is only meant to work with Welford, then it should be preceded by TORCH_INTERNAL_ASSERT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added TORCH_INTERNAL_ASSERT for the multioutput case. Thanks.
Thanks for the detailed review and helpful suggestions! 👍 |
torch/csrc/jit/codegen/cuda/arith.h
Outdated
@@ -46,6 +46,31 @@ TORCH_CUDA_API TensorView* reductionOp( | |||
TensorView* v1, | |||
bool keep_dim = false); | |||
|
|||
//! Auxiliary Struct holding result of | |||
//! a single welford op in ternsorview | |||
struct TORCH_CUDA_API WelfordResult { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: use class
instead of struct
for anything with methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Gave a couple of minor comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, just one comment on rfactor I'd like to see addressed then can approve.
namespace { | ||
|
||
template <typename T> | ||
kir::Allocate* allocGlobalBuffer( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use this to also simplify the grid reduction code? Would make more sense to do in a follow up if yes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I think so. Thanks for pointing this out. I will put further simplifications in a follow up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This PR intends to implement welford scheduling into the fuser pipeline.
(This PR will introduce non-trivial merging with #586 , I will rebase after #586 has merged)
More scheduling tests and welford scheduler in a subsequent PR.
Example math print containing welford:
Example kernel containing welford: